#[burn_tensor_testgen::testgen(ad_deform_conv2d)]
mod tests {
    use super::*;
    use burn_tensor::{module::deform_conv2d, ops::DeformConvOptions, Shape};

    /// Reference gradients for a basic deformable 2D convolution: batch 1, 2 -> 3 channels,
    /// 3x3 kernel, no padding, unit stride and dilation, one group and one offset group,
    /// on a 4x4 input.
    ///
    /// The expected values below are hard-coded; presumably they come from an external
    /// reference implementation (e.g. torchvision's `deform_conv2d`) — TODO confirm.
    #[test]
    fn test_deform_conv2d_basic() {
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 3,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 0,
            padding_2: 0,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        // Backend default device; its concrete type is inferred from the `from_floats` calls.
        let device = Default::default();
        let grads = Grads {
            // d/dx: shape [1, 2, 4, 4], the same shape as the input.
            x: TestTensor::from_floats(
                [[
                    [
                        [0.0000, 6.0678, 14.2071, 12.2477],
                        [11.2292, 33.7937, 50.1555, 44.0561],
                        [17.9294, 57.2174, 85.1505, 79.1840],
                        [18.0220, 73.6263, 126.8184, 151.6910],
                    ],
                    [
                        [0.0000, 8.9783, 20.7620, 17.7888],
                        [16.2326, 48.7386, 71.7961, 62.5845],
                        [25.3808, 80.5195, 119.0949, 110.0938],
                        [25.0567, 101.8461, 174.3329, 206.6013],
                    ],
                ]],
                &device,
            ),
            // d/d(offset): shape [1, 18, 2, 2]. The 2x2 spatial extent is the output size
            // (4x4 input, 3x3 kernel, no padding); 18 = 2 * 3 * 3 channels, presumably one
            // coordinate pair per kernel tap — confirm the channel ordering.
            offset: TestTensor::from_floats(
                [[
                    [[0.0000, 15.0000], [30.0000, 45.0000]],
                    [[0.0000, 3.7500], [7.5000, 11.2500]],
                    [[62.6667, 78.3333], [94.0000, 109.6667]],
                    [[15.6667, 19.5833], [23.5000, 27.4167]],
                    [[130.6667, 104.1250], [163.3333, 122.2732]],
                    [[32.6667, -492.9583], [40.8333, -787.1620]],
                    [[204.0000, 221.0000], [238.0000, 255.0000]],
                    [[51.0000, 55.2500], [59.5000, 63.7500]],
                    [[282.6667, 300.3333], [318.0000, 335.6667]],
                    [[70.6667, 75.0833], [79.5000, 83.9167]],
                    [[366.6667, 144.3750], [403.3333, 146.4121]],
                    [[91.6667, -1788.9860], [100.8333, -2392.7456]],
                    [[456.0000, 475.0000], [-2718.6250, -2953.2188]],
                    [[114.0000, 118.7500], [37.7361, 37.4063]],
                    [[550.6667, 570.3334], [-3404.5139, -3672.5312]],
                    [[137.6667, 142.5833], [28.6806, 27.5197]],
                    [[650.6667, 27.9584], [-4174.3657, -59.7509]],
                    [[162.6667, -3991.0139], [14.4028, -298.7557]],
                ]],
                &device,
            ),
            // d/d(weight): shape [3, 2, 3, 3]. All three output channels hold identical
            // values, consistent with an upstream gradient that is uniform across output
            // channels (assumption; see `assert_grads`).
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [0.7029, 2.8356, 5.1067],
                            [12.7492, 19.4745, 17.8345],
                            [22.0687, 25.9156, 14.6394],
                        ],
                        [
                            [3.3696, 12.6134, 19.2671],
                            [36.7492, 50.5856, 43.5506],
                            [50.8774, 56.3292, 30.7470],
                        ],
                    ],
                    [
                        [
                            [0.7029, 2.8356, 5.1067],
                            [12.7492, 19.4745, 17.8345],
                            [22.0687, 25.9156, 14.6394],
                        ],
                        [
                            [3.3696, 12.6134, 19.2671],
                            [36.7492, 50.5856, 43.5506],
                            [50.8774, 56.3292, 30.7470],
                        ],
                    ],
                    [
                        [
                            [0.7029, 2.8356, 5.1067],
                            [12.7492, 19.4745, 17.8345],
                            [22.0687, 25.9156, 14.6394],
                        ],
                        [
                            [3.3696, 12.6134, 19.2671],
                            [36.7492, 50.5856, 43.5506],
                            [50.8774, 56.3292, 30.7470],
                        ],
                    ],
                ],
                &device,
            ),
            // d/d(mask): shape [1, 9, 2, 2]; 9 = 3 * 3, presumably one channel per kernel tap.
            mask: TestTensor::from_floats(
                [[
                    [[1303.5000, 1447.8750], [1862.2500, 2006.6250]],
                    [[1571.1666, 1721.9581], [2154.7500, 2305.5417]],
                    [[1857.4999, 1396.7151], [2465.9167, 1753.2246]],
                    [[2315.5000, 2479.1250], [2948.7502, 3112.3750]],
                    [[2645.1665, 2815.2085], [3303.2500, 3473.2917]],
                    [[2993.5000, 1150.0625], [3676.4165, 1300.4055]],
                    [[3531.5000, 3714.3752], [1150.1876, 1148.4744]],
                    [[3923.1665, 4112.4585], [794.3865, 770.0470]],
                    [[4333.5000, 181.4101], [368.3260, 4.2679]],
                ]],
                &device,
            ),
            // d/d(bias): 4 = 1 sample * 2 * 2 output positions per channel, consistent with
            // an all-ones upstream gradient (assumption; see `assert_grads`).
            bias: TestTensor::from_floats([4., 4., 4.], &device),
        };
        // Presumably runs forward/backward for `test` and compares the results with the
        // expected gradients above; the comparison tolerance lives in `assert_grads`.
        test.assert_grads(grads);
    }

    /// Same configuration as `test_deform_conv2d_basic`, but with a batch of 2 samples.
    ///
    /// The expected values below are hard-coded; presumably they come from an external
    /// reference implementation (e.g. torchvision's `deform_conv2d`) — TODO confirm.
    #[test]
    fn test_deform_conv2d_batched() {
        let test = Conv2dTestCase {
            batch_size: 2,
            channels_in: 2,
            channels_out: 3,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 0,
            padding_2: 0,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        // Backend default device; its concrete type is inferred from the `from_floats` calls.
        let device = Default::default();
        let grads = Grads {
            // d/dx: shape [2, 2, 4, 4], one 2-channel 4x4 gradient per sample.
            x: TestTensor::from_floats(
                [
                    [
                        [
                            [0.0000, 3.4604, 8.7539, 6.8080],
                            [8.4661, 24.0784, 35.4610, 26.4276],
                            [19.5988, 51.0406, 68.4389, 53.4993],
                            [17.4698, 47.9106, 67.3808, 56.6063],
                        ],
                        [
                            [0.0000, 5.1185, 12.7803, 9.8796],
                            [12.1957, 34.5728, 50.4616, 37.3777],
                            [27.4521, 71.1227, 94.5778, 73.4724],
                            [24.1147, 65.8443, 91.8995, 76.7475],
                        ],
                    ],
                    [
                        [
                            [6.3750, 19.3553, 26.4935, 22.5650],
                            [17.0026, 57.8088, 85.5580, 78.0746],
                            [20.7334, 86.5793, 139.4667, 136.4133],
                            [16.8126, 103.0225, 186.4502, 206.9613],
                        ],
                        [
                            [9.5625, 28.8786, 39.1137, 32.9178],
                            [25.1984, 85.0747, 124.6941, 112.5691],
                            [30.0242, 124.2863, 198.6056, 192.4489],
                            [23.5826, 143.4660, 257.8752, 283.2587],
                        ],
                    ],
                ],
                &device,
            ),
            // d/d(offset): shape [2, 18, 2, 2]; per sample, 18 = 2 * 3 * 3 channels over the
            // 2x2 output grid (presumably one coordinate pair per kernel tap — confirm).
            offset: TestTensor::from_floats(
                [
                    [
                        [[0.0000, 7.5000], [15.0000, 22.5000]],
                        [[0.0000, 1.8750], [3.7500, 5.6250]],
                        [[31.3333, 39.1667], [47.0000, 54.8333]],
                        [[7.8333, 9.7917], [11.7500, 13.7083]],
                        [[65.3333, 62.7813], [81.6667, 75.4849]],
                        [[16.3333, -237.8021], [20.4167, -381.7280]],
                        [[102.0000, 110.5000], [119.0000, 127.5000]],
                        [[25.5000, 27.6250], [29.7500, 31.8750]],
                        [[141.3333, 150.1667], [159.0000, 167.8333]],
                        [[35.3333, 37.5417], [39.7500, 41.9583]],
                        [[183.3333, 132.3438], [201.6667, 142.0197]],
                        [[45.8333, -839.6840], [50.4167, -1133.4155]],
                        [[228.0000, 237.5000], [-1336.1562, -1452.1173]],
                        [[57.0000, 59.3750], [40.3090, 41.4141]],
                        [[275.3333, 285.1667], [-1670.5034, -1802.9244]],
                        [[68.8333, 71.2917], [44.0451, 44.9841]],
                        [[325.3333, 174.7396], [-2045.1747, -1090.4585]],
                        [[81.3333, -1844.0659], [46.8090, -1150.2101]],
                    ],
                    [
                        [[270.0000, 277.5000], [285.0000, 292.5000]],
                        [[67.5000, 69.3750], [71.2500, 73.1250]],
                        [[313.3333, 321.1667], [329.0000, 336.8333]],
                        [[78.3333, 80.2917], [82.2500, 84.2083]],
                        [[359.3333, 130.1563], [375.6667, 130.6099]],
                        [[89.8333, -4312.7603], [93.9167, -4893.6035]],
                        [[408.0000, 416.5000], [425.0000, 433.5000]],
                        [[102.0000, 104.1250], [106.2500, 108.3750]],
                        [[459.3333, 468.1667], [477.0000, 485.8333]],
                        [[114.8333, 117.0417], [119.2500, 121.4583]],
                        [[513.3334, 97.9688], [531.6667, 93.8947]],
                        [[128.3333, -6720.3926], [132.9167, -7504.5405]],
                        [[570.0000, 579.5000], [-7971.8438, -8251.0850]],
                        [[142.5000, 144.8750], [22.4965, 21.8203]],
                        [[629.3333, 639.1667], [-8948.2334, -9249.6641]],
                        [[157.3333, 159.7917], [15.7743, 14.8695]],
                        [[691.3333, 14.6145], [-9992.9453, -70.4040]],
                        [[172.8333, -9818.5234], [7.4132, -352.0222]],
                    ],
                ],
                &device,
            ),
            // d/d(weight): shape [3, 2, 3, 3], accumulated over both samples. The three
            // output channels hold identical values (assumed uniform upstream gradient).
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [77.7195, 89.8692, 69.0213],
                            [121.0760, 137.0775, 92.2989],
                            [100.0212, 106.5561, 61.1851],
                        ],
                        [
                            [112.3862, 131.6470, 103.8793],
                            [177.0760, 200.1887, 138.2681],
                            [149.5922, 158.7074, 94.3991],
                        ],
                    ],
                    [
                        [
                            [77.7195, 89.8692, 69.0213],
                            [121.0760, 137.0775, 92.2989],
                            [100.0212, 106.5561, 61.1851],
                        ],
                        [
                            [112.3862, 131.6470, 103.8793],
                            [177.0760, 200.1887, 138.2681],
                            [149.5922, 158.7074, 94.3991],
                        ],
                    ],
                    [
                        [
                            [77.7195, 89.8692, 69.0213],
                            [121.0760, 137.0775, 92.2989],
                            [100.0212, 106.5561, 61.1851],
                        ],
                        [
                            [112.3862, 131.6470, 103.8793],
                            [177.0760, 200.1887, 138.2681],
                            [149.5922, 158.7074, 94.3991],
                        ],
                    ],
                ],
                &device,
            ),
            // d/d(mask): shape [2, 9, 2, 2]; 9 = 3 * 3, presumably one channel per kernel tap.
            mask: TestTensor::from_floats(
                [
                    [
                        [[1299.7499, 1439.4375], [1849.1249, 1988.8125]],
                        [[1528.0834, 1673.9791], [2101.8750, 2247.7708]],
                        [[1771.7500, 1624.9811], [2369.9583, 2099.5039]],
                        [[2183.7500, 2342.0625], [2806.3750, 2964.6875]],
                        [[2464.0833, 2628.6042], [3111.1250, 3275.6458]],
                        [[2759.7500, 1979.2551], [3431.2085, 2390.0286]],
                        [[3241.7498, 3418.6873], [2415.3589, 2500.8682]],
                        [[3574.0835, 3757.2292], [2394.3889, 2471.7510]],
                        [[3921.7500, 2095.5293], [2345.9363, 1199.5048]],
                    ],
                    [
                        [[5957.2500, 6096.9375], [6506.6250, 6646.3125]],
                        [[6392.5835, 6538.4790], [6966.3750, 7112.2705]],
                        [[6843.2500, 2443.8982], [7441.4585, 2550.9199]],
                        [[7462.2505, 7620.5625], [8084.8745, 8243.1875]],
                        [[7949.5835, 8114.1045], [8596.6250, 8761.1465]],
                        [[8452.2500, 1591.6719], [9123.7080, 1589.9454]],
                        [[9141.2500, 9318.1875], [1414.3584, 1375.1803]],
                        [[9680.5840, 9863.7285], [949.0560, 897.3544]],
                        [[10235.2500, 213.4454], [428.2699, 2.4790]],
                    ],
                ],
                &device,
            ),
            // d/d(bias): 8 = 2 samples * 2 * 2 output positions per channel, consistent with
            // an all-ones upstream gradient (assumption; see `assert_grads`).
            bias: TestTensor::from_floats([8., 8., 8.], &device),
        };
        // Presumably runs forward/backward for `test` and compares the results with the
        // expected gradients above; the comparison tolerance lives in `assert_grads`.
        test.assert_grads(grads);
    }

    /// Reference gradients for a deformable 2D convolution with a non-square 3x4 kernel and
    /// padding 1 on both axes: batch 1, 2 -> 2 channels, 4x4 input, unit stride/dilation.
    ///
    /// The expected values below are hard-coded; presumably they come from an external
    /// reference implementation (e.g. torchvision's `deform_conv2d`) — TODO confirm.
    #[test]
    fn test_deform_conv2d_different_kernel_size() {
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 4,
            padding_1: 1,
            padding_2: 1,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        // Backend default device; its concrete type is inferred from the `from_floats` calls.
        let device = Default::default();
        let grads = Grads {
            // d/dx: shape [1, 2, 4, 4], the same shape as the input.
            x: TestTensor::from_floats(
                [[
                    [
                        [14.5585, 27.2496, 37.3820, 36.0394],
                        [33.1519, 60.4807, 81.2647, 78.6182],
                        [57.5201, 108.6233, 153.4136, 170.0730],
                        [54.7062, 102.5967, 144.3672, 162.6436],
                    ],
                    [
                        [25.8364, 48.0884, 65.2492, 62.1033],
                        [56.8052, 102.9956, 136.9831, 131.1209],
                        [96.1054, 179.7902, 250.5509, 272.6688],
                        [90.2110, 167.5679, 232.8473, 257.9347],
                    ],
                ]],
                &device,
            ),
            // d/d(offset): shape [1, 24, 4, 3]. The 4x3 spatial extent is the output size
            // (4x4 input, padding 1, 3x4 kernel); 24 = 2 * 3 * 4 channels, presumably one
            // coordinate pair per kernel tap — confirm the channel ordering. Several rows and
            // columns are exactly 0.0; presumably those taps sample only the padded border
            // (not verified here).
            offset: TestTensor::from_floats(
                [[
                    [
                        [0.0000, 5.3559, 11.7153],
                        [0.3125, 8.0000, 10.0000],
                        [0.7500, 14.0000, 16.0000],
                        [1.3125, 20.0000, 22.0000],
                    ],
                    [
                        [0.0000, 0.0017, 0.0069],
                        [16.0625, 2.0000, 2.5000],
                        [44.2500, 3.5000, 4.0000],
                        [84.5625, 5.0000, 5.5000],
                    ],
                    [
                        [67.4583, 79.9648, 93.5305],
                        [31.6667, 33.7778, 35.8889],
                        [38.0000, 40.1111, 42.2222],
                        [44.3333, 46.4444, 48.5556],
                    ],
                    [
                        [0.5278, 0.5956, 0.6671],
                        [7.9167, 8.4444, 8.9722],
                        [9.5000, 10.0278, 10.5556],
                        [11.0833, 11.6111, 12.1389],
                    ],
                    [
                        [154.7778, 175.1640, 151.8874],
                        [60.0000, 62.2222, 49.8997],
                        [66.6667, 68.8889, 54.3210],
                        [73.3333, 75.5555, 58.6034],
                    ],
                    [
                        [2.2222, 2.3630, -33.6034],
                        [15.0000, 15.5556, -227.7485],
                        [16.6667, 17.2222, -323.1605],
                        [18.3333, 18.8889, -432.0448],
                    ],
                    [
                        [264.1250, 202.1189, 0.0000],
                        [91.0000, 64.8148, 0.0000],
                        [98.0000, 68.6308, 0.0000],
                        [105.0000, 72.3009, 0.0000],
                    ],
                    [
                        [5.2500, -72.6832, 0.0000],
                        [22.7500, -334.6296, 0.0000],
                        [24.5000, -461.1053, 0.0000],
                        [26.2500, -601.7269, 0.0000],
                    ],
                    [
                        [44.0000, 119.7778, 122.2222],
                        [48.0486, 127.1111, 129.5556],
                        [52.2500, 134.4444, 136.8889],
                        [-313.8958, -800.7446, -850.7313],
                    ],
                    [
                        [337.7778, 29.9444, 30.5556],
                        [484.8542, 31.7778, 32.3889],
                        [646.7500, 33.6111, 34.2222],
                        [490.9653, 22.3989, 22.6599],
                    ],
                    [
                        [153.3333, 155.8889, 158.4444],
                        [161.0000, 163.5556, 166.1111],
                        [168.6667, 171.2222, 173.7778],
                        [-995.2491, -1054.5505, -1115.1342],
                    ],
                    [
                        [38.3333, 38.9722, 39.6111],
                        [40.2500, 40.8889, 41.5278],
                        [42.1667, 42.8056, 43.4444],
                        [24.3377, 24.5351, 24.7281],
                    ],
                    [
                        [192.0000, 194.6667, 89.0741],
                        [200.0000, 202.6667, 90.5463],
                        [208.0000, 210.6667, 91.8519],
                        [-1272.9375, -1343.5092, -581.1921],
                    ],
                    [
                        [48.0000, 48.6667, -741.3703],
                        [50.0000, 50.6667, -978.8981],
                        [52.0000, 52.6667, -1232.5927],
                        [25.3125, 25.4352, -638.8311],
                    ],
                    [
                        [233.3333, 87.7218, 0.0000],
                        [241.6667, 88.2716, 0.0000],
                        [250.0000, 88.6478, 0.0000],
                        [-1587.2161, -553.5372, 0.0000],
                    ],
                    [
                        [58.3333, -901.1902, 0.0000],
                        [60.4167, -1179.9877, 0.0000],
                        [62.5000, -1475.6252, 0.0000],
                        [24.8915, -621.3175, 0.0000],
                    ],
                    [
                        [196.4444, 280.2222, 283.1111],
                        [205.5625, 288.8889, 291.7778],
                        [-1173.4723, -1679.6113, -1771.2903],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [1144.8890, 70.0556, 70.7778],
                        [1469.6459, 72.2222, 72.9444],
                        [502.9167, 22.9882, 22.9506],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [324.0000, 327.0000, 330.0000],
                        [333.0000, 336.0000, 339.0000],
                        [-1931.4688, -2034.9608, -2139.9585],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [81.0000, 81.7500, 82.5000],
                        [83.2500, 84.0000, 84.7500],
                        [19.5938, 19.4661, 19.3333],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [373.3333, 376.4445, 44.8087],
                        [382.6667, 385.7778, 41.8596],
                        [-2313.7917, -2431.2759, -239.2101],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [93.3333, 94.1111, -1904.9321],
                        [95.6667, 96.4444, -2344.7146],
                        [14.2917, 14.0621, -341.7283],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [425.3333, 16.3684, 0.0000],
                        [435.0000, 12.1728, 0.0000],
                        [-2738.5173, -47.9289, 0.0000],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [106.3333, -2178.7473, 0.0000],
                        [108.7500, -2670.6790, 0.0000],
                        [6.9479, -162.9574, 0.0000],
                        [0.0000, 0.0000, 0.0000],
                    ],
                ]],
                &device,
            ),
            // d/d(weight): shape [2, 2, 3, 4] (out, in, kernel_h, kernel_w). Both output
            // channels hold identical values (assumed uniform upstream gradient).
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [1.8560, 7.2034, 12.8334, 11.9694],
                            [24.2368, 40.1255, 41.3964, 27.6420],
                            [43.6131, 57.5089, 46.0933, 25.1744],
                        ],
                        [
                            [6.9899, 26.5803, 42.6186, 37.5014],
                            [75.6232, 116.9257, 113.2884, 72.5678],
                            [112.7249, 139.8264, 107.6534, 56.7994],
                        ],
                    ],
                    [
                        [
                            [1.8560, 7.2034, 12.8334, 11.9694],
                            [24.2368, 40.1255, 41.3964, 27.6420],
                            [43.6131, 57.5089, 46.0933, 25.1744],
                        ],
                        [
                            [6.9899, 26.5803, 42.6186, 37.5014],
                            [75.6232, 116.9257, 113.2884, 72.5678],
                            [112.7249, 139.8264, 107.6534, 56.7994],
                        ],
                    ],
                ],
                &device,
            ),
            // d/d(mask): shape [1, 12, 4, 3]; 12 = 3 * 4, presumably one channel per kernel tap.
            mask: TestTensor::from_floats(
                [[
                    [
                        [0.0000, 2.6779, 5.8576],
                        [40.1562, 775.9999, 849.2499],
                        [66.3750, 1067.7499, 1140.9999],
                        [98.6563, 1359.5000, 1432.7499],
                    ],
                    [
                        [67.4583, 76.8892, 86.8497],
                        [838.7916, 916.1111, 993.4306],
                        [1146.7500, 1224.0695, 1301.3889],
                        [1454.7083, 1532.0278, 1609.3472],
                    ],
                    [
                        [154.7778, 171.6607, 146.0455],
                        [986.1667, 1067.5555, 875.6536],
                        [1310.3333, 1391.7222, 1110.8640],
                        [1634.5001, 1715.8888, 1339.3390],
                    ],
                    [
                        [264.1250, 199.3876, 0.0000],
                        [1144.8751, 836.5740, 0.0000],
                        [1485.2499, 1056.2528, 0.0000],
                        [1825.6250, 1268.8589, 0.0000],
                    ],
                    [
                        [380.0000, 1047.8611, 1137.3889],
                        [527.6354, 1404.4445, 1493.9722],
                        [682.6807, 1761.0276, 1850.5554],
                        [503.8855, 1256.3406, 1304.9355],
                    ],
                    [
                        [1123.5000, 1217.0972, 1310.6943],
                        [1496.2917, 1589.8889, 1683.4861],
                        [1869.0834, 1962.6805, 2056.2778],
                        [1146.6998, 1190.1357, 1232.9299],
                    ],
                    [
                        [1300.0001, 1397.6667, 651.2036],
                        [1689.0000, 1786.6667, 807.2734],
                        [2078.0000, 2175.6667, 955.2593],
                        [1060.7812, 1097.7451, 465.6539],
                    ],
                    [
                        [1487.8334, 567.2195, 0.0000],
                        [1893.0416, 697.2655, 0.0000],
                        [2298.2500, 818.8910, 0.0000],
                        [947.2098, 323.8781, 0.0000],
                    ],
                    [
                        [1216.4445, 1792.8055, 1898.6112],
                        [1536.4478, 2214.2222, 2320.0278],
                        [517.7084, 725.6571, 749.3920],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [1897.5000, 2007.3749, 2117.2500],
                        [2335.1250, 2445.0000, 2554.8750],
                        [559.1096, 575.0975, 590.3336],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [2119.3333, 2233.2776, 265.4414],
                        [2573.1667, 2687.1111, 290.7444],
                        [385.6317, 392.4502, 37.3766],
                        [0.0000, 0.0000, 0.0000],
                    ],
                    [
                        [2352.5000, 90.0985, 0.0000],
                        [2822.5415, 78.5491, 0.0000],
                        [178.5990, 2.9309, 0.0000],
                        [0.0000, 0.0000, 0.0000],
                    ],
                ]],
                &device,
            ),
            // d/d(bias): 12 = 1 sample * 4 * 3 output positions per channel, consistent with
            // an all-ones upstream gradient (assumption; see `assert_grads`).
            bias: TestTensor::from_floats([12., 12.], &device),
        };
        // Presumably runs forward/backward for `test` and compares the results with the
        // expected gradients above; the comparison tolerance lives in `assert_grads`.
        test.assert_grads(grads);
    }

    #[test]
    fn test_deform_conv2d_different_padding() {
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 2,
            padding_2: 3,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            offset_groups: 1,
            height: 4,
            width: 4,
        };
        let device = Default::default();
        let grads = Grads {
            x: TestTensor::from_floats(
                [[
                    [
                        [60.6330, 60.9065, 61.1795, 61.4520],
                        [122.5578, 123.0882, 123.6186, 124.1490],
                        [126.8011, 127.3315, 127.8619, 128.3924],
                        [131.0444, 131.5749, 132.1053, 132.6357],
                    ],
                    [
                        [102.0006, 102.4976, 102.9938, 103.4893],
                        [198.9330, 199.8306, 200.7282, 201.6259],
                        [206.1140, 207.0117, 207.9092, 208.8069],
                        [213.2949, 214.1926, 215.0903, 215.9879],
                    ],
                ]],
                &device,
            ),
            offset: TestTensor::from_floats(
                [[
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.8951, 14.7606, 17.6042, 20.6981, 22.2004, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.6875, 9.5000, 10.0000, 10.5000, 10.1088, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 1.1134, 13.5000, 14.0000, 14.5000, 13.6458, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 1.6134, 17.5000, 18.0000, 18.5000, 17.1088, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, -12.3958, -122.3994, -130.7523, -139.3555, -131.5268,
                            0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.1543, 0.0175, 0.0208, 0.0245, -0.3875, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 24.1875, 2.3750, 2.5000, 2.6250, -37.8634, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 48.0579, 3.3750, 3.5000, 3.6250, -66.7708, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 80.0023, 4.3750, 4.5000, 4.6250, -103.7523, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 113.2153, 5.1075, 5.2199, 5.3320, -139.7259, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 14.2060, 83.0176, 92.3794, 102.0100, 90.3563, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 6.5047, 35.4444, 35.9815, 36.5185, 29.9790, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 7.6683, 39.7407, 40.2778, 40.8148, 33.0719, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 8.9115, 44.0370, 44.5741, 45.1111, 36.0853, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, -57.5230, -274.2679, -289.5471, -305.0951, -248.5786, 0.0000,
                            0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 9.7492, 0.9554, 0.9810, 1.0069, -13.9305, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 96.0469, 8.8611, 8.9954, 9.1296, -129.9207, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 147.4348, 9.9352, 10.0694, 10.2037, -186.7187, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 207.4948, 11.0093, 11.1435, 11.2778, -252.1889, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 226.0500, 10.1534, 10.2520, 10.3504, -266.2553, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            44.2250, 159.8985, 176.6519, 193.6927, 146.2708, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            19.0508, 64.8704, 65.4444, 66.0185, 46.5532, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            21.0494, 69.4630, 70.0370, 70.6111, 49.1046, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            23.1331, 74.0556, 74.6296, 75.2037, 51.5710, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            -141.2003, -445.3022, -468.3810, -491.7472, -341.5531, 0.0000, 0.0000,
                            0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            35.6653, 3.5057, 3.5567, 3.6081, -48.7569, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            181.4047, 16.2176, 16.3611, 16.5046, -238.1361, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            263.8889, 17.3657, 17.5093, 17.6528, -326.4037, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            355.6433, 18.5139, 18.6574, 18.8009, -423.9413, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            318.7092, 14.3597, 14.4416, 14.5231, -369.8195, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 88.8467, 237.4784, 261.7312, 286.2899, 182.5087, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 37.6880, 94.7222, 95.3333, 95.9445, 57.4416, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 40.5625, 99.6111, 100.2222, 100.8333, 59.4107, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 43.5275, 104.5000, 105.1111, 105.7222, 61.2893, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, -258.3244, -618.3539, -649.3403, -680.6325, -397.1010,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 76.2294, 7.5641, 7.6417, 7.7197, -102.7922, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 272.0152, 23.6806, 23.8333, 23.9861, -351.9442, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 386.0625, 24.9028, 25.0556, 25.2083, -472.1479, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 509.9781, 26.1250, 26.2778, 26.4306, -602.2200, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 378.4102, 17.1237, 17.1875, 17.2510, -436.0007, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 157.6233, 331.9385, 365.2834, 398.9526, 205.9885, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 66.4959, 130.9259, 131.5741, 132.2222, 64.4360, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 70.3968, 136.1111, 136.7593, 137.4074, 65.6723, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 74.3938, 141.2963, 141.9444, 142.5926, 66.8125, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, -432.7980, -827.4919, -867.9785, -908.7894, -425.0742, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 140.1500, 14.0440, 14.1529, 14.2623, -187.6569, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 386.8139, 32.7315, 32.8935, 33.0556, -494.7796, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 538.9267, 34.0278, 34.1898, 34.3519, -653.4219, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 701.5059, 35.3241, 35.4861, 35.6482, -822.5306, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 416.0446, 18.9036, 18.9446, 18.9853, -476.7288, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            249.8765, 435.8685, 479.1788, 522.8320, 207.9198, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            105.4170, 170.6111, 171.2963, 171.9815, 64.7500, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            110.4417, 176.0926, 176.7778, 177.4630, 65.1560, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            115.5679, 181.5741, 182.2592, 182.9445, 65.4606, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            -662.7435, -1056.6418, -1107.5020, -1158.7047, -409.5102, 0.0000,
                            0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            227.1605, 22.9825, 23.1258, 23.2695, -303.1120, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            518.4952, 42.6528, 42.8241, 42.9954, -657.1573, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            712.2524, 44.0231, 44.1944, 44.3657, -857.8172, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            917.0740, 45.3935, 45.5648, 45.7361, -1069.5416, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            416.5815, 18.9978, 19.0131, 19.0280, -475.0315, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 151.7503, 210.1667, 210.8889, 211.6111, 57.5069, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 157.9293, 215.9444, 216.6667, 217.3889, 57.0522, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 164.2153, 221.7222, 222.4445, 223.1667, 56.4905, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, -931.7838, -1285.3538, -1346.5559, -1408.1194,
                            -346.7390, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 655.6700, 52.5417, 52.7222, 52.9028, -824.9468, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 890.9725, 53.9861, 54.1667, 54.3472, -1067.5250, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 1137.9375, 55.4306, 55.6111, 55.7917, -1321.7656,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 375.5806, 17.1810, 17.1695, 17.1576, -425.9937, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 213.5215, 256.6296, 257.3889, 258.1481, 41.6529, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 221.0156, 262.7037, 263.4630, 264.2222, 40.1766, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 228.6223, 268.7778, 269.5370, 270.2963, 38.5878, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, -1285.4667, -1554.2545, -1627.5306, -1701.1866, -228.2914,
                            0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 823.3806, 64.1574, 64.3472, 64.5370, -1028.5327, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 1107.2965, 65.6759, 65.8657, 66.0556, -1320.0975, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 1403.4730, 67.1944, 67.3843, 67.5741, -1623.9230, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 288.1514, 13.2018, 13.1585, 13.1148, -323.5778, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            288.7901, 306.5741, 307.3704, 308.1667, 15.7342, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            297.6968, 312.9444, 313.7407, 314.5370, 13.1389, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            306.7215, 319.3148, 320.1111, 320.9074, 10.4255, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            -1711.5431, -1844.0131, -1930.2366, -2016.8586, -46.8461, 0.0000,
                            0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            1011.3581, 76.6435, 76.8426, 77.0417, -1255.0455, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            1347.4663, 78.2361, 78.4352, 78.6343, -1599.1759, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            1696.4333, 79.8287, 80.0278, 80.2269, -1956.1649, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            146.7036, 6.6909, 6.6128, 6.5342, -159.2772, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                ]],
                &device,
            ),
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [10.3420, 22.9881, 35.6342],
                            [46.9202, 59.5663, 72.2124],
                            [80.8816, 92.5915, 104.1585],
                        ],
                        [
                            [29.2134, 68.8378, 108.4622],
                            [143.8251, 183.4495, 223.0739],
                            [228.0294, 256.7517, 283.8071],
                        ],
                    ],
                    [
                        [
                            [10.3420, 22.9881, 35.6342],
                            [46.9202, 59.5663, 72.2124],
                            [80.8816, 92.5915, 104.1585],
                        ],
                        [
                            [29.2134, 68.8378, 108.4622],
                            [143.8251, 183.4495, 223.0739],
                            [228.0294, 256.7517, 283.8071],
                        ],
                    ],
                ],
                &device,
            ),
            mask: TestTensor::from_floats(
                [[
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.4475, 7.3803, 8.8021, 10.3490, 11.1002, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 44.3438, 584.9374, 639.2500, 693.5624, 683.2628, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 68.3901, 803.4375, 857.7500, 912.0625, 874.6981, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 96.4734, 1021.9375, 1076.2500, 1130.5625, 1062.0959,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 121.3021, 1168.4879, 1218.3738, 1268.1349, 1169.4447,
                            0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 13.0845, 75.8609, 83.7678, 91.8090, 80.7282, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 118.9504, 649.4861, 707.8218, 766.1574, 658.0766, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 170.6608, 884.1713, 942.5070, 1000.8427, 837.8093, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 226.7073, 1118.8564, 1177.1923, 1235.5277, 1013.2059, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 234.9397, 1106.2139, 1153.4156, 1200.4827, 966.2489, 0.0000,
                            0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            42.5240, 153.0457, 168.3193, 183.7365, 138.1447, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            207.3196, 718.4328, 780.7916, 843.1504, 619.9750, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            290.2778, 969.3032, 1031.6620, 1094.0208, 784.4216, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            377.8711, 1220.1736, 1282.5325, 1344.8912, 944.2330, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            328.0830, 1025.4950, 1069.1306, 1112.6222, 766.0549, 0.0000, 0.0000,
                            0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 88.2382, 235.0552, 258.1943, 281.4864, 178.8585, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 305.5755, 789.8680, 856.2500, 922.6319, 572.4661,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 421.8090, 1056.9236, 1123.3055, 1189.6875, 719.5988,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 542.9767, 1323.9792, 1390.3612, 1456.7430, 861.7973,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 393.2916, 934.4397, 974.0104, 1013.4281, 586.9240,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 157.2149, 330.2274, 362.4734, 394.8816, 203.3744, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 424.3406, 867.4954, 937.9005, 1008.3055, 505.6405, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 578.8949, 1150.7361, 1221.1412, 1291.5463, 630.4140, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 738.6825, 1433.9769, 1504.3820, 1574.7871, 749.9543, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 429.9127, 816.5075, 850.7720, 884.8738, 411.1526, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            249.8765, 434.9642, 477.1987, 519.6047, 206.2156, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            560.3093, 949.5208, 1023.9491, 1098.3773, 422.4583, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            756.7681, 1248.9468, 1323.3750, 1397.8032, 521.2890, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            958.7592, 1548.3728, 1622.8009, 1697.2292, 614.5874, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            428.8339, 679.2698, 707.3463, 735.2509, 258.1694, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 0.0000, 707.6714, 1033.6875, 1112.1389, 1190.5902, 328.2950,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 947.7794, 1349.2986, 1427.7500, 1506.2014, 399.4381,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 1193.7189, 1664.9097, 1743.3611, 1821.8125, 464.7498,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 388.7379, 532.5035, 553.9629, 575.2411, 140.6583,
                            0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            0.0000, 880.7972, 1124.3936, 1206.8680, 1289.3427, 209.6276, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 1169.8828, 1456.1898, 1538.6644, 1621.1389, 247.7547, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 1465.0988, 1787.9861, 1870.4606, 1952.9352, 279.7515, 0.0000,
                            0.0000,
                        ],
                        [
                            0.0000, 297.3307, 356.3622, 369.8935, 383.2343, 50.9746, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                    [
                        [
                            1074.5679, 1219.4977, 1305.9954, 1392.4930, 71.1624, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            1416.2147, 1567.4791, 1653.9769, 1740.4746, 72.6899, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            1764.2908, 1915.4606, 2001.9585, 2088.4561, 67.7876, 0.0000, 0.0000,
                            0.0000,
                        ],
                        [
                            151.0184, 160.0550, 164.7761, 169.2984, 3.8659, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                        [
                            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                        ],
                    ],
                ]],
                &device,
            ),
            bias: TestTensor::from_floats([48., 48.], &device),
        };
        test.assert_grads(grads);
    }

    /// Hyper-parameters describing one deformable 2D convolution gradient test.
    ///
    /// Fields suffixed `_1` apply to the height axis and fields suffixed `_2`
    /// apply to the width axis.
    struct Conv2dTestCase {
        /// Number of samples in the batch.
        batch_size: usize,
        /// Number of input channels.
        channels_in: usize,
        /// Number of output channels.
        channels_out: usize,
        /// Height of the input feature map.
        height: usize,
        /// Width of the input feature map.
        width: usize,

        /// Kernel extent along the height axis.
        kernel_size_1: usize,
        /// Kernel extent along the width axis.
        kernel_size_2: usize,
        /// Padding applied to both sides of the height axis.
        padding_1: usize,
        /// Padding applied to both sides of the width axis.
        padding_2: usize,
        /// Stride along the height axis.
        stride_1: usize,
        /// Stride along the width axis.
        stride_2: usize,
        /// Dilation along the height axis.
        dilation_1: usize,
        /// Dilation along the width axis.
        dilation_2: usize,

        /// Number of convolution groups; `channels_in` is divided by it to get
        /// the per-group input channel count of the weight.
        groups: usize,
        /// Number of offset groups; scales the channel dimension of both the
        /// offset tensor (two values per kernel position) and the mask tensor.
        offset_groups: usize,
    }

    /// Expected gradients of a deformable convolution, one entry per input
    /// that `assert_grads` marks as requiring a gradient.
    struct Grads {
        /// Expected gradient with respect to the input feature map `x`.
        x: TestTensor<4>,
        /// Expected gradient with respect to the offset tensor.
        offset: TestTensor<4>,
        /// Expected gradient with respect to the mask tensor.
        mask: TestTensor<4>,
        /// Expected gradient with respect to the convolution weight.
        weight: TestTensor<4>,
        /// Expected gradient with respect to the bias (one value per output channel).
        bias: TestTensor<1>,
    }

    impl Conv2dTestCase {
        fn assert_grads(self, expected_grads: Grads) {
            let out_height =
                (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)
                    / self.stride_1
                    + 1;
            let out_width =
                (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)
                    / self.stride_2
                    + 1;

            let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
            let shape_offset = Shape::new([
                self.batch_size,
                2 * self.offset_groups * self.kernel_size_1 * self.kernel_size_2,
                out_height,
                out_width,
            ]);
            let shape_weight = Shape::new([
                self.channels_out,
                self.channels_in / self.groups,
                self.kernel_size_1,
                self.kernel_size_2,
            ]);
            let shape_mask = Shape::new([
                self.batch_size,
                self.offset_groups * self.kernel_size_1 * self.kernel_size_2,
                out_height,
                out_width,
            ]);
            let device = Default::default();
            let weight = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_weight)
                    .into_data(),
                &device,
            )
            .require_grad();
            let bias = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
                &device,
            )
            .require_grad();
            let x = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_x)
                    .into_data(),
                &device,
            )
            .require_grad();
            let offset = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_offset.clone())
                    .into_data(),
                &device,
            )
            .div_scalar(shape_offset.num_elements() as f32)
            .require_grad();

            let mask = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_mask.clone())
                    .into_data(),
                &device,
            )
            .div_scalar(shape_mask.num_elements() as f32)
            .require_grad();

            let output = deform_conv2d(
                x.clone(),
                offset.clone(),
                weight.clone(),
                Some(mask.clone()),
                Some(bias.clone()),
                DeformConvOptions::new(
                    [self.stride_1, self.stride_2],
                    [self.padding_1, self.padding_2],
                    [self.dilation_1, self.dilation_2],
                    self.groups,
                    self.offset_groups,
                ),
            );
            let grads = output.backward();

            // Assert
            let x_grad_actual = x.grad(&grads).unwrap();
            let offset_grad_actual = offset.grad(&grads).unwrap();
            let weight_grad_actual = weight.grad(&grads).unwrap();
            let mask_grad_actual = mask.grad(&grads).unwrap();
            let bias_grad_actual = bias.grad(&grads).unwrap();

            println!("Testing bias");
            expected_grads
                .bias
                .to_data()
                .assert_approx_eq(&bias_grad_actual.to_data(), 3);
            println!("Testing input");
            expected_grads
                .x
                .to_data()
                .assert_approx_eq(&x_grad_actual.to_data(), 3);
            println!("Testing offset");
            expected_grads
                .offset
                .to_data()
                .assert_approx_eq(&offset_grad_actual.to_data(), 3);
            println!("Testing mask");
            expected_grads
                .mask
                .to_data()
                .assert_approx_eq(&mask_grad_actual.to_data(), 3);
            println!("Testing weight");
            expected_grads
                .weight
                .to_data()
                .assert_approx_eq_diff(&weight_grad_actual.to_data(), 0.04);
        }
    }
}
